from enum import Enum
from typing import Dict, Tuple

import torch

# `__all__` must be a sequence of names: with a bare string, `from <module> import *`
# iterates over the individual characters and fails on the first one.
__all__ = ["MicroBatchManager"]


class Status(Enum):
    """States reported for a micro batch by `MicroBatchDescription.state` and `MicroBatchManager`."""

    # Returned by `MicroBatchManager` when `step` receives no arguments at all, or when
    # no description exists yet for the current index.
    PREFILL = 1
    # Returned by `MicroBatchDescription.state` when the current length is neither
    # `target_length` nor `target_length - 1`.
    GENERATE = 2
    # Returned when the current length equals `target_length`.
    DONE = 3
    # Returned when the current length equals `target_length - 1`.
    COOLDOWN = 4


class MicroBatchDescription:
    """
    Record the information of one micro batch and apply updates to it.

    This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
    details, please refer to the docs of these two classes below. Subclasses must override `cur_length`,
    which the `state` property relies on.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
            The base class itself does not read it.
        output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
        new_length (int): added to the initial length to obtain `target_length`.
    """

    def __init__(
        self,
        inputs_dict: Dict[str, torch.Tensor],
        output_dict: Dict[str, torch.Tensor],
        new_length: int,
    ) -> None:
        assert output_dict.get("hidden_states") is not None
        # The initial length is taken from the second-to-last dimension of `hidden_states`.
        self.mb_length = output_dict["hidden_states"].shape[-2]
        self.target_length = self.mb_length + new_length
        # Empty until the first call to `update` that carries an `output_dict`.
        self.kv_cache = ()

    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
        """
        Replace the stored kv cache with `output_dict["past_key_values"]` when `output_dict` is given.

        `new_token` is accepted for signature parity with the subclasses and is ignored here.
        """
        if output_dict is not None:
            self._update_kvcache(output_dict["past_key_values"])

    def _update_kvcache(self, kv_cache: Tuple):
        """Store `kv_cache`, which must be a tuple (tuple subclasses are accepted)."""
        # `isinstance` rather than an exact type comparison so that tuple subclasses
        # (e.g. namedtuples) are not rejected.
        assert isinstance(kv_cache, tuple)
        self.kv_cache = kv_cache

    @property
    def state(self):
        """
        Return the state of current micro batch: DONE when the current length is equal to the
        target length, COOLDOWN when it is one less than the target length, otherwise GENERATE.

        """
        # TODO: add the condition for early stopping
        if self.cur_length == self.target_length:
            return Status.DONE
        elif self.cur_length == self.target_length - 1:
            return Status.COOLDOWN
        else:
            return Status.GENERATE

    @property
    def cur_length(self):
        """
        Return the current sequence length of micro batch.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class. Previously this silently returned
                `None`, which made `state` report GENERATE forever.
        """
        raise NotImplementedError(f"{type(self).__name__} must override `cur_length`")


class HeadMicroBatchDescription(MicroBatchDescription):
    """
    This class is used to record the information of the first stage of pipeline. The first stage keeps the
    attributes `input_ids`, `attn_mask` and `new_tokens`, where `new_tokens` accumulates the tokens passed to
    `update`. Also due to the schedule of pipeline, the operation to update the information and the condition
    to determine the state is different from other stages.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
        output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
        new_length (int): the new length of the input sequence.

    """

    def __init__(
        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
    ) -> None:
        super().__init__(inputs_dict, output_dict, new_length)
        assert inputs_dict is not None
        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
        self.input_ids = inputs_dict["input_ids"]
        self.attn_mask = inputs_dict["attention_mask"]
        # Stays `None` until the first token is passed to `update`.
        self.new_tokens = None

    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
        """
        Update the kv cache (when `output_dict` is given) and, when `new_token` is given, append it to
        `new_tokens` and extend the attention mask by one column unless the micro batch is now DONE.
        """
        super().update(output_dict, new_token)
        if new_token is None:
            return
        self._update_newtokens(new_token)
        # The state is checked after the token has been appended, so the mask is not
        # extended any further once the target length is reached.
        if self.state is not Status.DONE:
            self._update_attnmask()

    def _update_newtokens(self, new_token: torch.Tensor):
        """Append `new_token` to `new_tokens` along the last dimension."""
        if self.new_tokens is None:
            self.new_tokens = new_token
        else:
            self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)

    def _update_attnmask(self):
        """Append one column of ones to the attention mask."""
        # Allocate the new column on the mask's own device: a hard-coded "cuda" makes
        # `torch.cat` fail whenever the mask lives on the CPU or on another device.
        extra_column = torch.ones(
            (self.attn_mask.shape[0], 1), dtype=torch.int64, device=self.attn_mask.device
        )
        self.attn_mask = torch.cat((self.attn_mask, extra_column), dim=-1)

    @property
    def cur_length(self):
        """
        When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token

        """
        if self.new_tokens is None:
            return self.mb_length
        # Measured along the last dimension, the one `_update_newtokens` concatenates on.
        return self.mb_length + self.new_tokens.shape[-1]


class BodyMicroBatchDescription(MicroBatchDescription):
    """
    Micro batch record for the pipeline stages other than the first one. These stages work with
    `hidden_states` and `past_key_values` only.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hidden states from previous stage.
        output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
        new_length (int): forwarded unchanged to the base class.
    """

    def __init__(
        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
    ) -> None:
        # Nothing stage-specific to store: construction is delegated to the base class.
        super().__init__(inputs_dict, output_dict, new_length)

    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
        """Forward both arguments unchanged to the base class `update`."""
        super().update(output_dict, new_token)

    @property
    def cur_length(self):
        """
        Current sequence length: `kv_cache[0][0].shape[-2]` plus 1 once a kv cache has been stored,
        `mb_length` while the cache is still empty.

        """
        if len(self.kv_cache) > 0:
            cached_length = self.kv_cache[0][0].shape[-2]
            return cached_length + 1
        return self.mb_length


class MicroBatchManager:
    """
    Keep one micro batch description per buffer index and report the state of the current one.

    Args:
        stage (int): stage id of current stage. Stage 0 uses `HeadMicroBatchDescription`, every other
            stage uses `BodyMicroBatchDescription`.
        new_length (int): the new length of the input sequence, passed to every description created.
        micro_batch_size (int): the micro batch size. Stored only; not read elsewhere in this class.
        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.

    """

    def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
        # Index of the micro batch currently being handled; advanced by `next`.
        self.idx = 0
        # Maps a buffer index to its micro batch description.
        self.mb_descrption_buffer = {}
        self.new_tokens_buffer = {}
        self.stage = stage
        self.new_length = new_length
        self.micro_batch_size = micro_batch_size
        self.buffer_size = micro_batch_buffer_size

    def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
        """
        Update the state of the micro batch manager and return the state of the current micro batch.

        When called without any argument, nothing is changed and `Status.PREFILL` is returned.
        Otherwise a description is created for the current index if it has none yet, and that
        description is then updated with `output_dict` and `new_token`.

        Args:
            inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
            output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
            new_token (torch.Tensor): the new token generated by current stage.
        """
        nothing_received = inputs_dict is None and output_dict is None and new_token is None
        if nothing_received:
            return Status.PREFILL
        # First time this index is seen: create its description before updating it.
        if self.cur_descrption is None:
            self._add_descrption(inputs_dict, output_dict)
        self.cur_descrption.update(output_dict, new_token)
        return self.cur_state

    def export_new_tokens(self):
        """Return the `new_tokens` of every stored description, flattened into one list of rows."""
        return [
            row
            for description in self.mb_descrption_buffer.values()
            for row in description.new_tokens.tolist()
        ]

    def is_micro_batch_done(self):
        """Return True only if at least one description is stored and all of them are DONE."""
        if not self.mb_descrption_buffer:
            return False
        descriptions = self.mb_descrption_buffer.values()
        return not any(description.state != Status.DONE for description in descriptions)

    def clear(self):
        """Drop every stored description."""
        self.mb_descrption_buffer.clear()

    def next(self):
        """Advance to the next buffer index, wrapping around at `buffer_size`."""
        following = self.idx + 1
        self.idx = following % self.buffer_size

    def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
        """Create the description for the current index, choosing the class from the stage id."""
        description_cls = HeadMicroBatchDescription if self.stage == 0 else BodyMicroBatchDescription
        self.mb_descrption_buffer[self.idx] = description_cls(inputs_dict, output_dict, self.new_length)

    def _remove_descrption(self):
        """Remove the description of the current index; raises `KeyError` if there is none."""
        del self.mb_descrption_buffer[self.idx]

    @property
    def cur_descrption(self) -> MicroBatchDescription:
        """Description stored for the current index, or `None` when there is none."""
        return self.mb_descrption_buffer.get(self.idx, None)

    @property
    def cur_kv_cache(self):
        """kv cache of the current description, or `None` when there is no current description."""
        description = self.cur_descrption
        return None if description is None else description.kv_cache

    @property
    def cur_state(self):
        """
        State of the current micro batch; PREFILL when the current index has no description yet.

        """
        description = self.cur_descrption
        return Status.PREFILL if description is None else description.state
